"""
Phase 3: r-g Dynamics and Demographic Tipping Points
======================================================
Demographics affect both r (aging -> lower r*) and g (aging -> lower growth).
Which effect dominates for debt sustainability?

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
Output: fiscal_dominance/output/tables/phase3_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute root of the project checkout; every other path is derived from it.
# NOTE(review): this is a hard-coded absolute path, so the script only runs on
# machines where the project lives at exactly this location — confirm whether
# it should be made configurable.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
FD_DIR = PROJECT_DIR / "fiscal_dominance"
PROCESSED_DIR = FD_DIR / "data" / "processed"  # main() reads fiscal_panel.csv from here
TABLE_DIR = FD_DIR / "output" / "tables"  # main() writes phase3_rg_results.csv here

# Put the project's multilateral/src directory at the front of the import
# path so the project-local `model` module (source of PanelGLS) is importable.
# The import must stay below the sys.path manipulation.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate a PanelGLS model, print a report, and tabulate the results.

    Args:
        y: Dependent-variable values passed straight to ``PanelGLS.fit``.
        X: Regressor matrix passed straight to ``PanelGLS.fit``.
        entity_ids: Per-observation entity identifiers.
        time_ids: Per-observation time identifiers.
        feature_names: Names for the columns of ``X``; forwarded to the
            model's ``summary`` and ``to_dataframe`` methods.
        label: Human-readable model name; printed in the banner and stored
            in the ``model`` column of the returned table.

    Returns:
        A ``(model, result_df)`` tuple: the fitted ``PanelGLS`` instance and
        the frame from ``model.to_dataframe`` extended with the columns
        ``model``, ``n_obs``, ``n_countries``, ``r_squared`` and ``rho``.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Console banner with the headline fit statistics.
    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={estimator.n_obs:,}, {estimator.n_countries} countries, "
          f"R²={estimator.r_squared:.4f}, rho={estimator.rho:.3f}")
    print(f"{rule}")
    estimator.summary(feature_names=feature_names)

    # Tag every coefficient row with model-level metadata so the per-model
    # tables can later be stacked into a single frame.
    table = estimator.to_dataframe(feature_names=feature_names)
    model_level = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    for column, value in model_level.items():
        table[column] = value

    return estimator, table


def _significance_stars(p):
    """Return the significance label ('***', '**', '*' or 'ns') for p-value *p*."""
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return 'ns'


def _estimate(df, dep_var, regressors, label, all_results, min_obs=200):
    """Fit one specification on the complete-case sample of *df*.

    Rows with a missing value in *dep_var* or in any of *regressors* are
    dropped. When fewer than *min_obs* rows remain, the model is skipped with
    a console notice and ``None`` is returned. Otherwise the fit is delegated
    to ``fit_and_report``, its result table is appended to *all_results*, and
    the fitted model is returned.
    """
    sample = df.dropna(subset=[dep_var] + regressors)
    if len(sample) < min_obs:
        print(f"\n  Skipping {label}: only {len(sample):,} complete obs "
              f"(need {min_obs})")
        return None

    model, result_df = fit_and_report(
        sample[dep_var].values, sample[regressors].values,
        sample['iso3'].values, sample['year'].values,
        regressors, label,
    )
    all_results.append(result_df)
    return model


def main():
    """Run all Phase 3 r-g specifications and save the stacked results.

    Reads ``fiscal_panel.csv`` from ``PROCESSED_DIR``, estimates each model
    whose required columns are present and whose complete-case sample is
    large enough, prints interpretation notes, and writes
    ``phase3_rg_results.csv`` to ``TABLE_DIR`` (created if it does not exist).

    Returns:
        The concatenated per-model result tables, or an empty DataFrame when
        no model was estimated.
    """
    print("=" * 70)
    print("PHASE 3: r-g Dynamics and Demographic Tipping Points")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    # Controls are optional: keep only those present in the panel.
    controls = [c for c in ('fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag')
                if c in df.columns]
    baseline = demo_vars + controls

    # =================================================================
    # 1. Z -> rate (Carvalho channel)
    # =================================================================
    rate_controls = [c for c in controls if c != 'kaopen']  # rates don't depend on openness
    _estimate(df, 'rate_fd', demo_vars + rate_controls,
              "Model 1: Z -> Rate (Carvalho Channel)", all_results)

    # =================================================================
    # 2. Z -> growth
    # =================================================================
    _estimate(df, 'rgdp_growth', baseline,
              "Model 2: Z -> Real GDP Growth", all_results)

    # =================================================================
    # 3. Z -> r-g (KEY TEST: which dominates?)
    # =================================================================
    m3 = _estimate(df, 'r_minus_g', baseline,
                   "Model 3: Z -> r-g (KEY TEST)", all_results)
    if m3 is not None:
        print("\n  >>> INTERPRETATION:")
        # demo_vars occupy the leading positions of `baseline`, so the
        # enumerate index is also the coefficient index.
        for idx, zv in enumerate(demo_vars):
            coef = m3.beta[idx]
            p = m3.pvalues[idx]
            if coef > 0:
                direction = 'r effect dominates (aging raises r-g)'
            else:
                direction = 'g effect dominates (aging lowers r-g)'
            print(f"    {zv}: {coef:.4f} (p={p:.4f}) {_significance_stars(p)} — {direction}")

    # Also with real r-g
    if 'r_minus_g_real' in df.columns:
        _estimate(df, 'r_minus_g_real', baseline,
                  "Model 3b: Z -> r-g (real rate version)", all_results)

    # =================================================================
    # 4. Nonlinear r-g: OADR spline
    # =================================================================
    spline_vars = [v for v in ('old_dep', 'oadr_above_15', 'oadr_above_20',
                               'oadr_above_25', 'oadr_above_30')
                   if v in df.columns]
    if spline_vars:
        m4 = _estimate(df, 'r_minus_g', spline_vars + controls,
                       "Model 4: r-g with OADR Spline", all_results)
        if m4 is not None:
            print("\n  >>> OADR SPLINE INTERPRETATION:")
            # Spline terms lead the regressor list, so position == coefficient index.
            for idx, sv in enumerate(spline_vars):
                coef = m4.beta[idx]
                p = m4.pvalues[idx]
                print(f"    {sv}: {coef:.4f} (p={p:.4f}) {_significance_stars(p)}")
    else:
        # Without any spline term this would be a controls-only regression
        # mislabelled as a spline model, so skip it instead.
        print("\n  Skipping Model 4: no OADR spline columns in panel")

    # =================================================================
    # 5. Pre/post GFC split
    # =================================================================
    periods = [('Pre-GFC (1990-2007)', (1990, 2007)),
               ('Post-GFC (2010-2024)', (2010, 2024))]
    for period_label, (first_year, last_year) in periods:
        in_period = (df['year'] >= first_year) & (df['year'] <= last_year)
        _estimate(df[in_period], 'r_minus_g', baseline,
                  f"Model 5: Z -> r-g ({period_label})", all_results,
                  min_obs=100)

    # =================================================================
    # 6. Income group split
    # =================================================================
    if 'high_income' in df.columns:
        for income_label, income_val in [('High Income', 1), ('Low/Middle Income', 0)]:
            _estimate(df[df['high_income'] == income_val], 'r_minus_g', baseline,
                      f"Model 6: Z -> r-g ({income_label})", all_results,
                      min_obs=100)

    # =================================================================
    # 7. Debt-level interaction: Z x high_debt -> r-g
    # =================================================================
    if 'high_debt' in df.columns:
        interactions = {
            'Z1_x_highdebt': df['Z_1'] * df['high_debt'],
            'Z2_x_highdebt': df['Z_2'] * df['high_debt'],
            'Z3_x_highdebt': df['Z_3'] * df['high_debt'],
        }
        # Build the interaction terms on a derived frame so the shared panel
        # is not modified for the models that follow.
        debt_df = df.assign(**interactions)
        vars_7 = demo_vars + ['high_debt'] + list(interactions) + controls
        _estimate(debt_df, 'r_minus_g', vars_7,
                  "Model 7: Z -> r-g with High Debt Interaction", all_results)

    # =================================================================
    # 8. Z -> debt change (direct debt dynamics)
    # =================================================================
    dep_var = 'debt_change'
    if dep_var in df.columns and df[dep_var].notna().sum() >= 200:
        if 'debt_lag' in df.columns:
            _estimate(df, dep_var, demo_vars + ['debt_lag'] + controls,
                      "Model 8: Z -> Debt Change (direct dynamics)", all_results)
        else:
            # The lagged debt level is a required regressor here; without it
            # dropna would raise KeyError, so skip explicitly.
            print("\n  Skipping Model 8: 'debt_lag' column not in panel")

    # =================================================================
    # Save all results
    # =================================================================
    if not all_results:
        return pd.DataFrame()

    results_df = pd.concat(all_results, ignore_index=True)
    # Ensure the output directory exists so a fresh checkout does not lose
    # the whole run at the final write.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    out_path = TABLE_DIR / "phase3_rg_results.csv"
    results_df.to_csv(out_path, index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {out_path}")
    print(f"  {len(results_df)} rows across {results_df['model'].nunique()} models")

    # Summary: Z_1 coefficient across r, g, and r-g models
    print(f"\n{'=' * 70}")
    print("Z_1 COEFFICIENT COMPARISON (rate vs growth vs r-g)")
    print("=" * 70)
    z1_rows = results_df[results_df['variable'] == 'Z_1'][
        ['model', 'coefficient', 'std_error', 'p_value', 'r_squared', 'n_obs']
    ]
    print(z1_rows.to_string(index=False, float_format='%.4f'))

    return results_df


# Script entry point: run the full Phase 3 pipeline. The returned table is
# bound to a module-level name (presumably for inspection in an interactive
# session — nothing in this file reads it afterwards).
if __name__ == "__main__":
    results = main()
